{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.data.all import *\n",
    "from fastai.text.core import *\n",
    "from fastai.text.models.awdlstm import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp text.models.core\n",
    "#default_cls_lvl 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Core text modules\n",
    "\n",
 Contain the">
    "> Contains the modules shared by the different architectures and the generic functions used to build models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "_model_meta = {AWD_LSTM: {'hid_name':'emb_sz', 'url':URLs.WT103_FWD, 'url_bwd':URLs.WT103_BWD,\n",
    "                          'config_lm':awd_lstm_lm_config, 'split_lm': awd_lstm_lm_split,\n",
    "                          'config_clas':awd_lstm_clas_config, 'split_clas': awd_lstm_clas_split},}\n",
    "              # Transformer: {'hid_name':'d_model', 'url':URLs.OPENAI_TRANSFORMER,\n",
    "              #               'config_lm':tfmer_lm_config, 'split_lm': tfmer_lm_split,\n",
    "              #               'config_clas':tfmer_clas_config, 'split_clas': tfmer_clas_split},\n",
    "              # TransformerXL: {'hid_name':'d_model',\n",
    "              #                'config_lm':tfmerXL_lm_config, 'split_lm': tfmerXL_lm_split,\n",
    "              #                'config_clas':tfmerXL_clas_config, 'split_clas': tfmerXL_clas_split}}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class LinearDecoder(Module):\n",
    "    \"To go on top of a RNNCore module and create a Language Model.\"\n",
    "    initrange=0.1\n",
    "\n",
    "    def __init__(self, n_out, n_hid, output_p=0.1, tie_encoder=None, bias=True):\n",
    "        self.decoder = nn.Linear(n_hid, n_out, bias=bias)\n",
    "        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)\n",
    "        self.output_dp = RNNDropout(output_p)\n",
    "        if bias: self.decoder.bias.data.zero_()\n",
    "        if tie_encoder: self.decoder.weight = tie_encoder.weight\n",
    "\n",
    "    def forward(self, input):\n",
    "        dp_inp = self.output_dp(input)\n",
    "        return self.decoder(dp_inp), input, dp_inp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text.models.awdlstm import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = AWD_LSTM(100, 20, 10, 2)\n",
    "x = torch.randint(0, 100, (10,5))\n",
    "r = enc(x)\n",
    "\n",
    "tst = LinearDecoder(100, 20, 0.1)\n",
    "y = tst(r)\n",
    "test_eq(y[1], r)\n",
    "test_eq(y[2].shape, r.shape)\n",
    "test_eq(y[0].shape, [10, 5, 100])\n",
    "\n",
    "tst = LinearDecoder(100, 20, 0.1, tie_encoder=enc.encoder)\n",
    "test_eq(tst.decoder.weight, enc.encoder.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SequentialRNN(nn.Sequential):\n",
    "    \"A sequential module that passes the reset call to its children.\"\n",
    "    def reset(self):\n",
    "        for c in self.children(): getattr(c, 'reset', noop)()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class _TstMod(Module):\n",
    "    def reset(self): print('reset')\n",
    "\n",
    "tst = SequentialRNN(_TstMod(), _TstMod())\n",
    "test_stdout(tst.reset, 'reset\\nreset')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def get_language_model(arch, vocab_sz, config=None, drop_mult=1.):\n",
    "    \"Create a language model from `arch` and its `config`.\"\n",
    "    meta = _model_meta[arch]\n",
    "    config = ifnone(config, meta['config_lm']).copy()\n",
    "    for k in config.keys():\n",
    "        if k.endswith('_p'): config[k] *= drop_mult\n",
    "    tie_weights,output_p,out_bias = map(config.pop, ['tie_weights', 'output_p', 'out_bias'])\n",
    "    init = config.pop('init') if 'init' in config else None\n",
    "    encoder = arch(vocab_sz, **config)\n",
    "    enc = encoder.encoder if tie_weights else None\n",
    "    decoder = LinearDecoder(vocab_sz, config[meta['hid_name']], output_p, tie_encoder=enc, bias=out_bias)\n",
    "    model = SequentialRNN(encoder, decoder)\n",
    "    return model if init is None else model.apply(init)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default `config` used can be found in `_model_meta[arch]['config_lm']`. `drop_mult` is applied to all the probabilities of dropout in that config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = awd_lstm_lm_config.copy()\n",
    "config.update({'n_hid':10, 'emb_sz':20})\n",
    "\n",
    "tst = get_language_model(AWD_LSTM, 100, config=config)\n",
    "x = torch.randint(0, 100, (10,5))\n",
    "y = tst(x)\n",
    "test_eq(y[0].shape, [10, 5, 100])\n",
    "test_eq(y[1].shape, [10, 5, 20])\n",
    "test_eq(y[2].shape, [10, 5, 20])\n",
    "test_eq(tst[1].decoder.weight, tst[0].encoder.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#test drop_mult\n",
    "tst = get_language_model(AWD_LSTM, 100, config=config, drop_mult=0.5)\n",
    "test_eq(tst[1].output_dp.p, config['output_p']*0.5)\n",
    "for rnn in tst[0].rnns: test_eq(rnn.weight_p, config['weight_p']*0.5)\n",
    "for dp in tst[0].hidden_dps: test_eq(dp.p, config['hidden_p']*0.5)\n",
    "test_eq(tst[0].encoder_dp.embed_p, config['embed_p']*0.5)\n",
    "test_eq(tst[0].input_dp.p, config['input_p']*0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _pad_tensor(t, bs):\n",
    "    if t.size(0) < bs: return torch.cat([t, t.new_zeros(bs-t.size(0), *t.shape[1:])])\n",
    "    return t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SentenceEncoder(Module):\n",
    "    \"Create an encoder over `module` that can process a full sentence.\"\n",
    "    def __init__(self, bptt, module, pad_idx=1, max_len=None): store_attr('bptt,module,pad_idx,max_len')\n",
    "    def reset(self): getattr(self.module, 'reset', noop)()\n",
    "\n",
    "    def forward(self, input):\n",
    "        bs,sl = input.size()\n",
    "        self.reset()\n",
    "        mask = input == self.pad_idx\n",
    "        outs,masks = [],[]\n",
    "        for i in range(0, sl, self.bptt):\n",
    "            #Note: this expects that sequence really begins on a round multiple of bptt\n",
    "            real_bs = (input[:,i] != self.pad_idx).long().sum()\n",
    "            o = self.module(input[:real_bs,i: min(i+self.bptt, sl)])\n",
    "            if self.max_len is None or sl-i <= self.max_len:\n",
    "                outs.append(o)\n",
    "                masks.append(mask[:,i: min(i+self.bptt, sl)])\n",
    "        outs = torch.cat([_pad_tensor(o, bs) for o in outs], dim=1)\n",
    "        mask = torch.cat(masks, dim=1)\n",
    "        return outs,mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Warning: This module expects the inputs padded with most of the padding first, with the sequence beginning at a round multiple of `bptt` (and the rest of the padding at the end). Use `pad_input_chunk` to get your data in a suitable format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = nn.Embedding(5, 10)\n",
    "tst = SentenceEncoder(5, mod, pad_idx=0)\n",
    "x = torch.randint(1, 5, (3, 15))\n",
    "x[2,:5]=0\n",
    "out,mask = tst(x)\n",
    "\n",
    "test_eq(out[:1], mod(x)[:1])\n",
    "test_eq(out[2,5:], mod(x)[2,5:])\n",
    "test_eq(mask, x==0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def masked_concat_pool(output, mask, bptt):\n",
    "    \"Pool `MultiBatchEncoder` outputs into one vector [last_hidden, max_pool, avg_pool]\"\n",
    "    lens = output.shape[1] - mask.long().sum(dim=1)\n",
    "    last_lens = mask[:,-bptt:].long().sum(dim=1)\n",
    "    avg_pool = output.masked_fill(mask[:, :, None], 0).sum(dim=1)\n",
    "    avg_pool.div_(lens.type(avg_pool.dtype)[:,None])\n",
    "    max_pool = output.masked_fill(mask[:,:,None], -float('inf')).max(dim=1)[0]\n",
    "    x = torch.cat([output[torch.arange(0, output.size(0)),-last_lens-1], max_pool, avg_pool], 1) #Concat pooling.\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = torch.randn(2,4,5)\n",
    "mask = tensor([[True,True,False,False], [False,False,False,True]])\n",
    "x = masked_concat_pool(out, mask, 2)\n",
    "\n",
    "test_close(x[0,:5], out[0,-1])\n",
    "test_close(x[1,:5], out[1,-2])\n",
    "test_close(x[0,5:10], out[0,2:].max(dim=0)[0])\n",
    "test_close(x[1,5:10], out[1,:3].max(dim=0)[0])\n",
    "test_close(x[0,10:], out[0,2:].mean(dim=0))\n",
    "test_close(x[1,10:], out[1,:3].mean(dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Test the result is independent of padding by replacing the padded part by some random content\n",
    "out1 = torch.randn(2,4,5)\n",
    "out1[0,2:] = out[0,2:].clone()\n",
    "out1[1,:3] = out[1,:3].clone()\n",
    "x1 = masked_concat_pool(out1, mask, 2)\n",
    "test_eq(x, x1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class PoolingLinearClassifier(Module):\n",
    "    \"Create a linear classifier with pooling\"\n",
    "    def __init__(self, dims, ps, bptt, y_range=None):\n",
    "        if len(ps) != len(dims)-1: raise ValueError(\"Number of layers and dropout values do not match.\")\n",
    "        acts = [nn.ReLU(inplace=True)] * (len(dims) - 2) + [None]\n",
    "        layers = [LinBnDrop(i, o, p=p, act=a) for i,o,p,a in zip(dims[:-1], dims[1:], ps, acts)]\n",
    "        if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "        self.layers = nn.Sequential(*layers)\n",
    "        self.bptt = bptt\n",
    "\n",
    "    def forward(self, input):\n",
    "        out,mask = input\n",
    "        x = masked_concat_pool(out, mask, self.bptt)\n",
    "        x = self.layers(x)\n",
    "        return x, out, out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = nn.Embedding(5, 10)\n",
    "tst = SentenceEncoder(5, mod, pad_idx=0)\n",
    "x = torch.randint(1, 5, (3, 15))\n",
    "x[2,:5]=0\n",
    "out,mask = tst(x)\n",
    "\n",
    "test_eq(out[:1], mod(x)[:1])\n",
    "test_eq(out[2,5:], mod(x)[2,5:])\n",
    "test_eq(mask, x==0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "mod = nn.Embedding(5, 10)\n",
    "tst = nn.Sequential(SentenceEncoder(5, mod, pad_idx=0), PoolingLinearClassifier([10*3,4], [0.], 5))\n",
    "\n",
    "x = torch.randint(1, 5, (3, 14))\n",
    "x[2,:5] = 0\n",
    "res,raw,out = tst(x) \n",
    "\n",
    "test_eq(raw[:1], mod(x)[:1])\n",
    "test_eq(raw[2,5:], mod(x)[2,5:])\n",
    "test_eq(out[:1], mod(x)[:1])\n",
    "test_eq(out[2,5:], mod(x)[2,5:])\n",
    "test_eq(res.shape, [3,4])\n",
    "\n",
    "x1 = torch.cat([x, tensor([0,0,0])[:,None]], dim=1)\n",
    "res1,raw1,out1 = tst(x1) \n",
    "test_eq(res, res1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def get_text_classifier(arch, vocab_sz, n_class, seq_len=72, config=None, drop_mult=1., lin_ftrs=None,\n",
    "                        ps=None, pad_idx=1, max_len=72*20, y_range=None):\n",
    "    \"Create a text classifier from `arch` and its `config`, maybe `pretrained`\"\n",
    "    meta = _model_meta[arch]\n",
    "    config = ifnone(config, meta['config_clas']).copy()\n",
    "    for k in config.keys():\n",
    "        if k.endswith('_p'): config[k] *= drop_mult\n",
    "    if lin_ftrs is None: lin_ftrs = [50]\n",
    "    if ps is None:  ps = [0.1]*len(lin_ftrs)\n",
    "    layers = [config[meta['hid_name']] * 3] + lin_ftrs + [n_class]\n",
    "    ps = [config.pop('output_p')] + ps\n",
    "    init = config.pop('init') if 'init' in config else None\n",
    "    encoder = SentenceEncoder(seq_len, arch(vocab_sz, **config), pad_idx=pad_idx, max_len=max_len)\n",
    "    model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps, bptt=seq_len, y_range=y_range))\n",
    "    return model if init is None else model.apply(init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = awd_lstm_clas_config.copy()\n",
    "config.update({'n_hid':10, 'emb_sz':20})\n",
    "\n",
    "tst = get_text_classifier(AWD_LSTM, 100, 3, config=config)\n",
    "x = torch.randint(2, 100, (10,5))\n",
    "y = tst(x)\n",
    "test_eq(y[0].shape, [10, 3])\n",
    "test_eq(y[1].shape, [10, 5, 20])\n",
    "test_eq(y[2].shape, [10, 5, 20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#test padding gives same results\n",
    "tst.eval()\n",
    "y = tst(x)\n",
    "x1 = torch.cat([x, tensor([2,1,1,1,1,1,1,1,1,1])[:,None]], dim=1)\n",
    "y1 = tst(x1)\n",
    "test_close(y[0][1:],y1[0][1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#test drop_mult\n",
    "tst = get_text_classifier(AWD_LSTM, 100, 3, config=config, drop_mult=0.5)\n",
    "test_eq(tst[1].layers[1][1].p, 0.1)\n",
    "test_eq(tst[1].layers[0][1].p, config['output_p']*0.5)\n",
    "for rnn in tst[0].module.rnns: test_eq(rnn.weight_p, config['weight_p']*0.5)\n",
    "for dp in tst[0].module.hidden_dps: test_eq(dp.p, config['hidden_p']*0.5)\n",
    "test_eq(tst[0].module.encoder_dp.embed_p, config['embed_p']*0.5)\n",
    "test_eq(tst[0].module.input_dp.p, config['input_p']*0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.image_sequence.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.azureml.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
